import os
import wget

# Output directories for the MNIST test/train splits, rooted at the current
# working directory. The trailing "/" is preserved so each value can be joined
# to a file name by plain string concatenation.
_cwd = os.getcwd()
test_path = os.path.join(_cwd, "datasets/test/")
train_path = os.path.join(_cwd, "datasets/train/")

# Create both directories up front (no error if they already exist).
for _directory in (test_path, train_path):
    os.makedirs(_directory, exist_ok=True)
del _cwd, _directory

# Remote location of the MNIST files and the file names for each split.
base_url = "https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/"
train_files = [
    "train-labels-idx1-ubyte",
    "train-images-idx3-ubyte",
]
test_files = [
    "t10k-labels-idx1-ubyte",
    "t10k-images-idx3-ubyte",
]


def download(url, path, filename):
    """Download *url* into directory *path* as *filename*, unless already present.

    Args:
        url: Remote location of the file.
        path: Destination directory (with or without a trailing separator).
        filename: Name to save the file under inside *path*. This is also the
            name checked to decide whether the download can be skipped.

    Returns:
        The local path of the file. If the file already existed, this is
        ``os.path.join(path, filename)``; otherwise it is the path reported
        by ``wget.download``.
    """
    # os.path.join works whether or not *path* ends with a separator; plain
    # string concatenation silently produced a wrong path without one.
    target = os.path.join(path, filename)
    if os.path.isfile(target):
        return target
    # Give wget the full target path rather than just the directory. With a
    # directory, wget derives the saved name from the URL/response, which may
    # not match *filename*. The skip check above would then never succeed.
    return wget.download(url, target)


# Fetch every MNIST file for each split (train first, then test). Plain loops
# are used because download() is called only for its side effect; a list
# comprehension here would just build and discard a list of return values.
for split_dir, split_files in (
    (train_path, train_files),
    (test_path, test_files),
):
    for file_name in split_files:
        download(base_url + file_name, split_dir, file_name)